function [Phi_NxYz,grid_Y,varargout] = MakeSparseTransMatrix_NxYz(parms,gesol,rho_Y,F_Nxzz,NY,prob_Nxz)
%MAKESPARSETRANSMATRIX_NXYZ Sparse transition matrix on the joint (N,x,Y,z) grid.
%   [Phi_NxYz,grid_Y] = MakeSparseTransMatrix_NxYz(parms,gesol,rho_Y,F_Nxzz,NY,prob_Nxz)
%   builds the (NN*Nx*NY*Nz)-by-(NN*Nx*NY*Nz) sparse matrix Phi_NxYz whose
%   entry (i,j) is the probability of moving from state i to state j. States
%   are ordered column-major as (N,x,Y,z).
%   - N' comes from gen_law_of_motion_N(parms,gesol).
%   - x' comes from gen_law_of_motion_x(parms).
%   - Y follows Y' = rho_Y*Y + F(N,x,z,z').
%   - z' is drawn with parms.StateTransitionProbs(z,z').
%   Next-period values of N, x and Y that fall between grid points are split
%   between the two neighbouring points with linear-interpolation weights.
%
%   [Phi_NxYz,grid_Y,prob_NxYz] = ... also returns the stationary
%   distribution of Phi_NxYz as an NN-by-Nx-by-NY-by-Nz array.
%
%   Inputs:
%     parms    - struct with fields NN, Nx, Nz, grid_N (evenly spaced,
%                increasing), grid_x (increasing) and StateTransitionProbs
%                (Nz-by-Nz, rows indexed by the current z).
%     gesol    - passed through to gen_law_of_motion_N.
%     rho_Y    - persistence of Y; must satisfy |rho_Y| < 1.
%     F_Nxzz   - NN-by-Nx-by-Nz-by-Nz array: the shock F(N,x,z,z') to Y.
%     NY       - number of grid points for Y; must be > 1.
%     prob_Nxz - NN-by-Nx-by-Nz probability weights over (N,x,z), used only
%                to centre and scale the Y grid.
%
%   Outputs:
%     Phi_NxYz - sparse transition matrix (every row sums to one).
%     grid_Y   - NY-by-1 evenly spaced grid for Y.
%
%   Errors are raised for invalid inputs, for next-period N or x values that
%   lie outside their grids, and if the assembled matrix is not stochastic.

    if abs(rho_Y)>=1; error('should have |rho_Y|<1'); end
    if NY<=1; error('should have NY>1'); end

    NN = parms.NN; Nx = parms.Nx; Nz = parms.Nz;

    % Compare each dimension explicitly: size() drops trailing singleton
    % dimensions, so counting them would wrongly reject Nz == 1.
    if ndims(F_Nxzz) > 4 || size(F_Nxzz,1)~=NN || size(F_Nxzz,2)~=Nx ...
            || size(F_Nxzz,3)~=Nz || size(F_Nxzz,4)~=Nz
        error('F_Nxzz has the wrong dimension');
    end
    % Guard against implicit expansion silently broadcasting a mis-shaped array.
    if ndims(prob_Nxz) > 3 || size(prob_Nxz,1)~=NN || size(prob_Nxz,2)~=Nx ...
            || size(prob_Nxz,3)~=Nz
        error('prob_Nxz has the wrong dimension');
    end

    P_zz = parms.StateTransitionProbs;

    % 1. use prob_Nxz to figure out the range for Y
    % temp_F(N,x,z) = sum_z' P(z,z')*F(N,x,z,z')
    temp_F = sum(reshape(P_zz,[1 1 Nz Nz]).*F_Nxzz,4);
    mean_F = sum(prob_Nxz.*temp_F,'all');
    % Clip at zero: cancellation can make the difference slightly negative when
    % F is (nearly) constant, and sqrt would then return a complex number.
    var_F = max(sum(prob_Nxz.*(temp_F.^2),'all') - mean_F^2, 0);
    % NOTE(review): this is the spread of E[F | N,x,z]. It leaves out the
    % variation of F over z' given (N,x,z), so it can understate the spread of
    % F itself - confirm this is intended.
    std_F = sqrt(var_F);
    mean_Y = mean_F/(1 - rho_Y);
    std_Y = std_F/sqrt(1 - rho_Y^2); % note: this is true for the types of F that we consider

    % A half-width of at least 0.001 keeps the grid non-degenerate when std_Y ~ 0.
    half_width = max(4*std_Y, 0.001);
    Y_lb = mean_Y - half_width;
    Y_ub = mean_Y + half_width;
    grid_Y = linspace(Y_lb,Y_ub,NY)';

    dN = parms.grid_N(2) - parms.grid_N(1);
    if dN <= 0
        error('grid for N should be increasing');
    end
    if any(abs(diff(parms.grid_N) - dN)>1e-10)
        error('grid for N should be evenly spaced');
    end
    if any(diff(parms.grid_x)<=0)
        error('grid for x should be increasing');
    end

    % All next-period arrays below are NN x Nx x NY x Nz x Nz,
    % indexed (N, x, Y, z, z').

    % dynamics for N' (depends on the current z only; replicated over Y and z')
    Nnext_Nxz = gen_law_of_motion_N(parms,gesol);
    Nnext_NxYzz = permute(repmat(Nnext_Nxz,[1 1 1 NY Nz]),[1 2 4 3 5]);

    % dynamics for x' = x'(N,x,z,z') (replicated over Y)
    x_next_Nxzz = reshape(gen_law_of_motion_x(parms),[NN Nx Nz Nz]);
    x_next_NxYzz = permute(repmat(x_next_Nxzz,[1 1 1 1 NY]),[1 2 5 3 4]);

    % dynamics for Y'=rho_Y*Y + F(N,x,z,z'); Y' is clamped to its grid by design
    Y_next_NxYzz = rho_Y*permute(repmat(grid_Y(:),[1 NN Nx Nz Nz]),[2 3 1 4 5]) ...
        + permute(repmat(F_Nxzz,[1 1 1 1 NY]),[1 2 5 3 4]);
    Y_next_NxYzz(Y_next_NxYzz<Y_lb) = Y_lb;
    Y_next_NxYzz(Y_next_NxYzz>Y_ub) = Y_ub;

    % Lower bracketing index and the interpolation weight on it in each dimension.
    % The upper neighbour (index + 1) receives weight 1 - w.
    [iN1, wN1] = bracket_on_grid(parms.grid_N, Nnext_NxYzz, 'N');
    [ix1, wx1] = bracket_on_grid(parms.grid_x, x_next_NxYzz, 'x');
    [iY1, wY1] = bracket_on_grid(grid_Y, Y_next_NxYzz, 'Y');

    % 2. assemble the sparse matrix
    M = NN*Nx*NY;   % number of (N,x,Y) states
    % Collapse (N,x,Y) into one linear index, giving M x Nz x Nz arrays.
    iN1 = reshape(iN1,[M Nz Nz]);
    ix1 = reshape(ix1,[M Nz Nz]);
    iY1 = reshape(iY1,[M Nz Nz]);
    wN1 = reshape(wN1,[M Nz Nz]); wN = {wN1, 1 - wN1};
    wx1 = reshape(wx1,[M Nz Nz]); wx = {wx1, 1 - wx1};
    wY1 = reshape(wY1,[M Nz Nz]); wY = {wY1, 1 - wY1};

    % Row index of state (m, z) is m + (z-1)*M; the same row appears for every
    % z' and for each of the 8 interpolation corners.
    from_Mzz = repmat((1:M)' + M*(0:Nz-1), [1 1 Nz]);
    % Column offset for z' (dimension 3) and P(z,z') laid out over dims 2-3.
    zoff_next = reshape(M*(0:Nz-1),[1 1 Nz]);
    P_1zz = reshape(P_zz,[1 Nz Nz]);

    % Preallocate all triplets; growing the arrays inside the loop is quadratic.
    nBlock = M*Nz*Nz;
    Phi_idx_from = repmat(from_Mzz(:), 8, 1);
    Phi_idx_to = zeros(8*nBlock,1);
    Phi_val = zeros(8*nBlock,1);
    k = 0;
    for a = 0:1            % lower/upper neighbour in N
        for b = 0:1        % lower/upper neighbour in x
            for c = 0:1    % lower/upper neighbour in Y
                rows = k*nBlock + (1:nBlock);
                to_NxY = sub2ind_3d_fast([NN Nx NY], iN1 + a, ix1 + b, iY1 + c);
                to_Mzz = to_NxY + zoff_next;
                val_Mzz = P_1zz.*wN{a+1}.*wx{b+1}.*wY{c+1};
                Phi_idx_to(rows) = to_Mzz(:);
                Phi_val(rows) = val_Mzz(:);
                k = k + 1;
            end
        end
    end

    % sparse() sums duplicate (row,col) pairs, merging corners that coincide.
    Phi_NxYz = sparse(Phi_idx_from,Phi_idx_to,Phi_val,M*Nz,M*Nz);

    if any(nonzeros(Phi_NxYz)<0) || any(abs(1 - sum(Phi_NxYz,2))>1e-10)
        error('invalid transition matrix')
    end

    if nargout==3
        % Compute the stationary probability as the left eigenvector for
        % eigenvalue 1. Every eigenvalue of a row-stochastic matrix has
        % |lambda| <= 1, so 1 has the largest real part. The default
        % largest-magnitude selector can instead return e.g. -1 for a
        % periodic chain.
        [V,~,FLAG] = eigs(Phi_NxYz',1,'largestreal');
        if FLAG~=0; error('Failed to find stationary distribution'); end
        v = real(V(:,1));
        prob_NxYz = reshape(v./sum(v),[NN Nx NY Nz]);
        varargout(1) = {prob_NxYz};
    elseif nargout>3
        error('At most three outputs.');
    end

end

function [idx1, w1] = bracket_on_grid(grid, v, name)
% Locate each element of v on the strictly increasing vector grid for linear
% interpolation.
%   idx1 has the size of v and satisfies grid(idx1) <= v <= grid(idx1+1).
%   w1 is the weight on grid(idx1); the weight on grid(idx1+1) is 1 - w1.
%   Both weights lie in [0,1].
% Values outside [grid(1), grid(end)] by at most 1e-10 of the grid span are
% treated as rounding noise and clamped to the end point. Values further out,
% or NaN, raise an error naming the variable (name).
    grid = grid(:);
    lo = grid(1);
    hi = grid(end);
    tol = 1e-10*(hi - lo);
    % Written as ~all(inside) so that NaN also fails the test.
    if ~all(v(:) >= lo - tol & v(:) <= hi + tol)
        error('next-period %s lies outside its grid', name);
    end
    v = min(max(v, lo), hi);
    % discretize puts grid(end) itself in the last bin, so idx1 <= numel(grid)-1.
    idx1 = discretize(v, grid);
    % reshape so the result follows v's shape even when v is a row vector.
    lower = reshape(grid(idx1), size(v));
    upper = reshape(grid(idx1 + 1), size(v));
    w1 = (upper - v)./(upper - lower);
end

function ind = sub2ind_2d_fast(sz, i1, i2)
% Column-major linear index of (i1,i2) in an array whose first dimension has
% length sz(1). Unlike sub2ind, no range validation is performed. Array-valued
% i1/i2 are combined element-wise.
    nRows = sz(1);
    ind = (i2 - 1).*nRows + i1;
end

function ind = sub2ind_3d_fast(sz, i1, i2, i3)
% Column-major linear index of (i1,i2,i3) in an array of size sz(1)-by-sz(2)-by-...
% Unlike sub2ind, no range validation is performed. Array-valued subscripts are
% combined element-wise (nested Horner form of the usual stride sum).
    ind = i1 + sz(1)*((i2 - 1) + sz(2)*(i3 - 1));
end